from __future__ import annotations
from typing import Any, Dict, Tuple
import numpy as np

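# Important: import `patches` before anything else (importing it applies the patches),
# and load models/ConfigSpace only through its loader functions so everything stays consistent.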
import patches  # Importing applies patches
from patches import load_models, load_configspace_with_patches

class SurrogateEvaluator:
    """
    Unified wrapper around the NB301 surrogate models.

    ``predict(config, representation='configspace') -> (score, runtime)``.

    Only the DARTS v1.0 surrogates are loaded today (XGB performance model +
    LGB runtime model, via ``patches.load_models(version="1.0")``). The
    ``bench`` argument is stored for future branching (e.g. ``bench='fbnet'``)
    but does not yet select different models; a warning is emitted when a
    benchmark other than the default is requested so the mismatch is not silent.
    """

    # Benchmark identifiers (lower-cased) that are actually backed by models.
    _SUPPORTED_BENCHES = ("darts-xgb-v1.0",)

    def __init__(self, bench: str = "darts-xgb-v1.0"):
        """
        Load the surrogate models and the DARTS ConfigSpace.

        Args:
            bench: Benchmark identifier (case-insensitive); stored lower-cased
                on ``self.bench``. Only ``"darts-xgb-v1.0"`` is currently
                supported; any other value still loads DARTS v1.0 and warns.
        """
        self.bench = bench.lower()
        if self.bench not in self._SUPPORTED_BENCHES:
            import warnings

            # Previously an unknown bench was silently ignored, which would
            # yield DARTS predictions for e.g. an FBNet experiment.
            warnings.warn(
                f"SurrogateEvaluator: bench={bench!r} is not supported yet; "
                "falling back to the DARTS v1.0 surrogates.",
                stacklevel=2,
            )
        # Always DARTS v1.0 for now, regardless of `bench`.
        self.perf_model, self.runtime_model = load_models(version="1.0")
        self.cs = load_configspace_with_patches()  # DARTS ConfigSpace

    @property
    def configspace(self):
        """The ConfigSpace loaded by ``load_configspace_with_patches()``."""
        return self.cs

    @staticmethod
    def _to_float(value: Any) -> float:
        """Convert a scalar or single-element array-like model output to a float.

        Raises ValueError if ``value`` holds more than one element.
        """
        # `.item()` accepts size-1 arrays of any rank, unlike `float(arr)`,
        # which rejects ndim > 1 and is deprecated for ndim > 0 in NumPy.
        return float(np.asarray(value).item())

    def predict(
        self,
        config: Dict[str, Any],
        representation: str = "configspace",
        *,
        with_noise: bool = True,
    ) -> Tuple[float, float]:
        """
        Query the surrogates for a single architecture.

        Args:
            config: Architecture encoding forwarded unchanged to the NB301
                models; its expected form depends on ``representation``.
            representation: Encoding name forwarded to both ``predict`` calls
                (default ``"configspace"``).
            with_noise: Forwarded to the performance model's ``predict``.
                Defaults to ``True`` (the previously hard-coded behaviour).

        Returns:
            ``(score, runtime)`` as Python floats. ``runtime`` is ``0.0`` when
            no runtime model was loaded.

        Raises:
            ValueError: if a model returns an array with more than one element.
        """
        # NOTE(review): the score is passed through unchanged. NB301's DARTS
        # performance surrogate is documented as predicting validation accuracy
        # (higher is better); confirm whether callers minimise or maximise it.
        score = self._to_float(
            self.perf_model.predict(
                config=config, representation=representation, with_noise=with_noise
            )
        )

        # Runtime: fall back to 0.0 only when no runtime model is available.
        # `is None` rather than truthiness, so a model object that happens to
        # evaluate as falsy is still queried.
        if self.runtime_model is None:
            runtime = 0.0
        else:
            runtime = self._to_float(
                self.runtime_model.predict(config=config, representation=representation)
            )
        return score, runtime
